"""
------------------------------------------------------------------------------------------------------------------------
wfeatures.py
Copyright (C) 2019-22 - NFStream Developers
This file is part of NFStream, a Flexible Network Data Analysis Framework (https://www.nfstream.org/).
NFStream is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
version.
NFStream is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with NFStream.
If not, see <http://www.gnu.org/licenses/>.
------------------------------------------------------------------------------------------------------------------------
"""

import numpy as np
import numpy.typing as npt
import pywt

from nfstream import NFPlugin


class WFPlugin(NFPlugin):
    """Wavelet-based Features plugin. This plugin attempts to recreate wavelet-based features from [1].

    Features are calculated from `ip_size`, which is binned on packet timestamps into a time
    series of length 2**`levels` (spanning `active_timeout`). One series is kept per direction.

    Generated features will have names like:

    - For forward traffic: `udps.f_p_k_*`, `udps.f_u_k_*`, `udps.f_sigma_k_*`, `udps.f_S_k_*`
    - For backward traffic: `udps.b_p_k_*`, `udps.b_u_k_*`, `udps.b_sigma_k_*`, `udps.b_S_k_*`

    where `*` is a number from 0 to `levels`. So for given `levels`, `8*(levels + 1)` features are
    calculated.

    Plugin arguments:
        active_timeout (int):
            Maximum flow duration in seconds (it is multiplied by 1000 to match the
            millisecond packet timestamps). Presumably this should equal the streamer's
            own `active_timeout` -- confirm against the caller.
        levels (int):
            Levels of wavelets to calculate.


    [1]: Bujlow, T., Carela-Español, V. & Barlet-Ros, P. Extended Independent Comparison of
        Popular Deep Packet Inspection (DPI) Tools for Traffic Classification. 460.
    """

    def on_init(self, packet, flow):
        """Set up the per-flow time series and bin the flow's first packet.

        Raises:
            AttributeError: if `levels` or `active_timeout` was not supplied as a
                plugin argument.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        for required in ("levels", "active_timeout"):
            if not hasattr(self, required):
                raise AttributeError(
                    f"WFPlugin requires the `{required}` plugin argument"
                )

        # The stationary wavelet transform needs an input of length 2**levels,
        # so that is the number of bins of the time series.
        self.nbins = 2**self.levels

        # Width of one bin in ms, with `active_timeout` (s) as max flow length.
        self.bin_size = self.active_timeout / self.nbins * 1000

        # Zero-filled time series for the forward and backward directions.
        flow.udps.forward = [0.0] * self.nbins
        flow.udps.backward = [0.0] * self.nbins

        # Timestamp (ms) of the first packet; bin indices are relative to it.
        flow.udps.first_packet_timestamp = packet.time

        # `on_update` is only invoked for the packets that follow the first one,
        # so the packet that created the flow has to be binned here.
        self.on_update(packet, flow)

    def on_update(self, packet, flow):
        """Add the packet's `ip_size` to the time-series bin matching its timestamp."""
        # Time in ms elapsed since the first packet of the flow.
        elapsed_ms = packet.time - flow.udps.first_packet_timestamp

        # Clamp the bin index: a packet at/after `active_timeout` (or float
        # rounding) would otherwise raise IndexError, and a negative offset
        # would silently wrap around to the end of the list.
        ibin = int(elapsed_ms // self.bin_size)
        ibin = min(max(ibin, 0), self.nbins - 1)

        # direction == 0 is src -> dst (forward), anything else is dst -> src.
        series = flow.udps.forward if packet.direction == 0 else flow.udps.backward
        series[ibin] += packet.ip_size

    def on_expire(self, flow):
        """Compute the wavelet features for both directions and drop the temporaries."""
        feature_names = ("p_k", "u_k", "sigma_k", "S_k")
        directions = (("f", flow.udps.forward), ("b", flow.udps.backward))
        for prefix, series in directions:
            features = self.calculate_wavelet_features(data=series, level=self.levels)
            for name, values in zip(feature_names, features):
                self.add_attrs_from_list(flow, values, f"{prefix}_{name}")

        # Delete temporary variables so they are not exported with the flow.
        del flow.udps.forward
        del flow.udps.backward
        del flow.udps.first_packet_timestamp

    @staticmethod
    def add_attrs_from_list(flow, values, attr):
        """Add one attribute per element of `values` to `flow.udps`.

        Attribute names are `{attr}_{i}` where `i` is the element's index.
        """
        for i, value in enumerate(values):
            setattr(flow.udps, f"{attr}_{i}", value)

    @staticmethod
    def calculate_wavelet_features(
        data: npt.ArrayLike, level
    ) -> tuple[
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
        npt.NDArray[np.float64],
    ]:
        """Calculate wavelet features based on [1].

        Args:
            data
                1-D data. Will be zero-padded at the end to length 2**level.
            level
                Number of decomposition levels of the stationary Haar transform.

        Returns:
            4 vectors of shape: (level+1, ), these vectors are:
            - `p_k` (Relative wavelet energy vector),
            - `u_k` (Absolute mean of coefficients),
            - `sigma_k` (Std. deviation of coefficients),
            - `S_k` (Shannon entropy)

            With `trim_approx=True` index 0 holds the final approximation band and
            indices 1..level hold the detail bands, coarsest first.
        """
        data = np.pad(data, (0, 2**level - len(data)))
        # One column per band, one row per time bin: shape (2**level, level + 1).
        d = np.array(pywt.swt(data, "haar", level=level, trim_approx=True)).T

        squared = np.square(d)
        energy_k = np.sum(squared, axis=0)
        energy_total = np.sum(energy_k, axis=0)

        # The 1e-7 terms guard against division by zero / log(0) on empty series.
        # Relative wavelet energy
        p_k = energy_k / (energy_total + 1e-7)
        # Energy distribution of the coefficients inside each band
        p_n = squared / (energy_k + 1e-7)
        # Shannon entropy
        s_k = -np.sum(p_n * np.log(p_n + 1e-7), axis=0)
        # Absolute mean of coefficients
        u_k = np.mean(np.abs(d), axis=0)
        # Std. deviation of coefficients
        sigma_k = np.std(d, axis=0)
        return p_k, u_k, sigma_k, s_k
